{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# {class}`~tvm.contrib.msc.framework.tvm.runtime.TVMRunner`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/media/pc/data/lxw/ai/tvm-book/doc/tutorials/msc\n"
     ]
    }
   ],
   "source": [
    "# Move the kernel's working directory up one level (IPython line magic).\n",
    "# NOTE: the path is relative, so re-running this cell moves up one more level.\n",
    "%cd ..\n",
    "from pathlib import Path\n",
    "\n",
    "# Scratch directory for the runner workspaces created below; made on first use.\n",
    "temp_dir = Path(\".temp\")\n",
    "temp_dir.mkdir(exist_ok=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "构建前端模型："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "from torch import fx\n",
    "from tvm.relax.frontend.torch import from_fx\n",
    "\n",
    "def _get_torch_model(name, training=False):\n",
    "    \"\"\"Instantiate a torchvision model by name and set its train/eval mode.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    name : str\n",
    "        Attribute name of the model constructor in ``torchvision.models``.\n",
    "    training : bool\n",
    "        If True the model is switched to training mode, otherwise to eval mode.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    model or None\n",
    "        The constructed model, or None (after printing a hint) when\n",
    "        torchvision cannot be imported.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    AttributeError\n",
    "        If ``torchvision.models`` has no attribute called ``name``.\n",
    "    \"\"\"\n",
    "\n",
    "    # torchvision is optional, so it is imported lazily. Only a failed import\n",
    "    # is reported as a missing package; any other error (for example an\n",
    "    # unknown model name) propagates instead of being hidden behind that hint.\n",
    "    # pylint: disable=import-outside-toplevel\n",
    "    try:\n",
    "        import torchvision\n",
    "    except ImportError:\n",
    "        print(\"please install torchvision package\")\n",
    "        return None\n",
    "\n",
    "    model = getattr(torchvision.models, name)()\n",
    "    return model.train() if training else model.eval()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 构建并运行 `TVMRunner`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tvm.contrib.msc.framework.tvm.runtime import TVMRunner\n",
    "import tvm\n",
    "from tvm.contrib.msc.core import utils as msc_utils\n",
    "\n",
    "def _test_from_torch(runner_cls, device, training=False, atol=1e-1, rtol=1e-1):\n",
    "    \"\"\"Run resnet50 through ``runner_cls`` and compare with the torch output.\n",
    "\n",
    "    The torch model is traced with ``fx.symbolic_trace``, converted with\n",
    "    ``from_fx``, built and run by the runner on one random input, and the\n",
    "    result is checked against the torch output with ``assert_allclose``.\n",
    "    Nothing is checked when ``_get_torch_model`` returns no model.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    runner_cls : type\n",
    "        Runner class, called as ``runner_cls(mod, device=..., training=...)``.\n",
    "    device : str\n",
    "        Device passed to the runner; also part of the workspace path.\n",
    "    training : bool\n",
    "        Selects train/eval mode of the torch model; when True the batch\n",
    "        dimension given to ``from_fx`` is the symbolic variable ``bz``.\n",
    "    atol, rtol : float\n",
    "        Tolerances forwarded to ``np.testing.assert_allclose``.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    AssertionError\n",
    "        If the runner output differs from the torch output beyond tolerance.\n",
    "    \"\"\"\n",
    "\n",
    "    torch_model = _get_torch_model(\"resnet50\", training)\n",
    "    if not torch_model:\n",
    "        return\n",
    "\n",
    "    path = f\"{temp_dir}/test_runner_torch_{runner_cls.__name__}_{device}\"\n",
    "    workspace = msc_utils.set_workspace(msc_utils.msc_dir(path, keep_history=False))\n",
    "    try:\n",
    "        log_path = workspace.relpath(\"MSC_LOG\", keep_history=False)\n",
    "        msc_utils.set_global_logger(\"critical\", log_path)\n",
    "        # The concrete data always uses batch size 1, even in training mode.\n",
    "        input_info = [([1, 3, 224, 224], \"float32\")]\n",
    "        datas = [np.random.rand(*i[0]).astype(i[1]) for i in input_info]\n",
    "        torch_datas = [torch.from_numpy(d) for d in datas]\n",
    "        graph_model = fx.symbolic_trace(torch_model)\n",
    "        if training:\n",
    "            # Only the shape handed to from_fx becomes symbolic in the batch dim.\n",
    "            input_info = [([tvm.tir.Var(\"bz\", \"int64\"), 3, 224, 224], \"float32\")]\n",
    "        with torch.no_grad():\n",
    "            golden = torch_model(*torch_datas)\n",
    "            mod = from_fx(graph_model, input_info)\n",
    "        runner = runner_cls(mod, device=device, training=training)\n",
    "        runner.build()\n",
    "        outputs = runner.run(datas, ret_type=\"list\")\n",
    "        golden = [msc_utils.cast_array(golden)]\n",
    "    finally:\n",
    "        # Clean up the workspace even when conversion, build or run fails.\n",
    "        # The method name is spelled exactly as in the original call.\n",
    "        workspace.destory()\n",
    "    for gol_r, out_r in zip(golden, outputs):\n",
    "        np.testing.assert_allclose(gol_r, msc_utils.cast_array(out_r), atol=atol, rtol=rtol)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# CPU target: check both the training and the inference configuration.\n",
    "for training in (True, False):\n",
    "    _test_from_torch(TVMRunner, \"cpu\", training=training)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# CUDA target: only run when a CUDA device is actually present, so that the\n",
    "# notebook still runs top to bottom on CPU-only machines.\n",
    "if tvm.cuda().exist:\n",
    "    for training in (True, False):\n",
    "        _test_from_torch(TVMRunner, \"cuda\", training=training)\n",
    "else:\n",
    "    print(\"CUDA device not available, skipping the cuda run\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "xxx",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
